
Conversation

@mydatascience
Collaborator

Description

Refactors GRPO and adds unified functionality that makes it easy to add new models.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- `grpo_llama3_1_8b_demo_pw.py` - Pathways-based 8B model
- `grpo_llama3_1_70b_demo_pw.py` - Pathways-based 70B model

These have been consolidated into a single **unified CLI script** (`grpo_demo.py`) that works with the new **grpo.yml** configuration file.
Collaborator

again: should this be "demo"?

To me, "demo" suggests it may not be suitable for production workloads.

@github-actions

🤖 Hi @A9isha, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Collaborator

Should we move evaluate_rl.py, rl_utils.py, and train_rl.py into src/MaxText/rl?

Collaborator

+1

# ====== Debug flag for verbose logs ======
DEBUG = tmvp_config.debug

print("Starting GRPO Training")
Collaborator

Use max_logging.log instead of print.
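The swap the comment asks for might look like the stand-in sketch below; treating max_logging.log as a thin wrapper over Python's logging module is an assumption, so this is illustrative rather than the real MaxText helper:

```python
import logging

# Stand-in for MaxText's max_logging.log (assumed to wrap a module-level
# logger); the point is to route status messages through logging, which
# honors handlers and levels, instead of bare print().
logger = logging.getLogger("grpo_demo")


def log(msg):
  """Emit an info-level message through the shared logger."""
  logger.info(msg)


log("Starting GRPO Training")
```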

os.makedirs(data_dir)

data = tfds.data_source(
"gsm8k",
Collaborator

Dataset name should come from tmvp_config.
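A resolution helper along these lines would implement the suggestion; note that a dataset_name field on tmvp_config is the reviewer's proposal, not an existing MaxText option, so both the field and the helper are hypothetical:

```python
from types import SimpleNamespace

# Hypothetical sketch: dataset_name on tmvp_config is an assumption taken
# from the review comment; "gsm8k" stays as the fallback so current
# behavior is preserved when the field is absent or empty.
def resolve_dataset_name(config, default="gsm8k"):
  """Return the configured TFDS dataset name, or the GSM8K default."""
  return getattr(config, "dataset_name", None) or default


# The result would replace the hard-coded "gsm8k" literal passed to
# tfds.data_source(...).
demo_config = SimpleNamespace(dataset_name="gsm8k")
```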


# ====== Data ======
# Setup data directories
home = os.path.expanduser("~") + "/"
Collaborator

If we switch to the Hugging Face API to load the dataset, we won't need to set up these data directories; Hugging Face downloads the data into a cache. Not sure if TFDS can also do that.
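One way to keep the script testable while moving to Hugging Face loading is to inject the loader callable; datasets.load_dataset caches downloads itself (under the local HF cache by default), which is the behavior the comment refers to. The helper below is a hypothetical sketch, not existing MaxText code:

```python
def load_gsm8k_splits(loader, name="gsm8k", config_name="main"):
  """Fetch the train/test splits through a Hugging Face-style loader.

  In the demo script, `loader` would be datasets.load_dataset, which
  manages its own download cache, so no manual data_dir setup (or
  os.makedirs bookkeeping) is needed. Injecting the loader keeps this
  sketch runnable without network access.
  """
  train = loader(name, config_name, split="train")
  test = loader(name, config_name, split="test")
  return train, test
```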

os.makedirs(test_data_dir)

# Create model tokenizer
model_tokenizer = AutoTokenizer.from_pretrained(tmvp_config.hf_model_name)
Collaborator

base.yml has tokenizer_path that can be used here, right?
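Assuming tmvp_config picks up tokenizer_path from base.yml as the comment suggests, the resolution could be sketched like this; the field names and fallback policy here are assumptions, not confirmed MaxText behavior:

```python
from types import SimpleNamespace

# Sketch: prefer an explicit tokenizer_path (assumed to come from base.yml)
# and fall back to hf_model_name, which preserves the current behavior of
# AutoTokenizer.from_pretrained(tmvp_config.hf_model_name) when no
# tokenizer path is configured.
def resolve_tokenizer_source(config):
  """Return the identifier to hand to AutoTokenizer.from_pretrained."""
  return getattr(config, "tokenizer_path", "") or config.hf_model_name
```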

Collaborator

@xuefgu left a comment

Thanks for the PR @A9isha !

In addition to the comments: can you please clarify what tests you performed, i.e., the hardware, the model, and, most importantly, the configs (since those are the main change here)?

Collaborator

+1

Comment on lines +52 to +112
# # Install vLLM for Jax and TPUs from the artifact registry
# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
# --extra-index-url https://pypi.org/simple/ \
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu

# # Install tpu-commons from the artifact registry
# RUN pip install --no-cache-dir --pre \
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
# --extra-index-url https://pypi.org/simple/ \
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
# tpu-commons==0.1.2

# # Uninstall existing jax to avoid conflicts
# # RUN pip uninstall -y jax jaxlib libtpu

# # --- STAGE 1: Install Static Dependencies ---
# # Install any packages *not* defined in your project dependency files
# RUN --mount=type=cache,target=/root/.cache/pip pip install \
# aiohttp==3.12.15\
# keyring \
# keyrings.google-artifactregistry-auth

# RUN --mount=type=cache,target=/root/.cache/pip pip install \
# numba==0.61.2

# # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
# # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---

# # Copy *only* the dependency definition files.
# # This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
# COPY vllm/requirements/tpu.txt /tmp/
# COPY vllm/requirements/build.txt /tmp/
# COPY vllm/requirements/common.txt /tmp/
# COPY tpu-inference/requirements.txt /tmp/

# # Run the full dependency installation.
# # This entire layer is cached and will *only* be rebuilt if
# # these .txt files change.
# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
# # Set the target device so pip installs the right JAX/libtpu
# # Install tpu-inference dependencies
# export VLLM_TARGET_DEVICE="tpu" && \
# pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
# --extra-index-url https://pypi.org/simple/ \
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Collaborator

Please clean up this large block of commented code. Some of it is no longer relevant.

os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
)

tmvp_config = pyconfig.initialize(argv)
Collaborator

Just use config as the variable name?

Collaborator

I think tmvp_config helps convey that all the configs are gathered together here

Comment on lines +129 to +142
num_trainer_devices = int(num_devices * tmvp_config.trainer_devices_fraction)
num_sampler_devices = int(num_devices * tmvp_config.sampler_devices_fraction)
Collaborator

In config, should we add a check for "if using pathways, trainer_devices_fraction + sampler_devices_fraction should not exceed 1"? I find the behavior hard to reason about if the sum is larger than 1 for disaggregated RL.

Collaborator

we do allow trainer_devices_fraction = sampler_devices_fraction = 1.0, where the full mesh is used for both training and inference, i.e., colocated rather than disaggregated, but still multihost
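The check suggested above could be sketched as follows; the function name and the explicit disaggregated flag are illustrative, not existing MaxText config fields:

```python
# Hypothetical validation sketch based on the review thread: disaggregated
# runs must not over-subscribe the mesh, while the colocated case (both
# fractions 1.0, sharing the full mesh) stays allowed.
def check_device_fractions(trainer_fraction, sampler_fraction, disaggregated):
  """Raise ValueError for fraction pairs that cannot be scheduled."""
  for name, frac in (("trainer", trainer_fraction), ("sampler", sampler_fraction)):
    if not 0.0 < frac <= 1.0:
      raise ValueError(f"{name}_devices_fraction must be in (0, 1], got {frac}")
  if disaggregated and trainer_fraction + sampler_fraction > 1.0:
    raise ValueError(
        "with disaggregated RL, trainer_devices_fraction + "
        "sampler_devices_fraction must not exceed 1.0"
    )
```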

# Load policy model
print("Creating policy model with same config as reference model on trainer mesh")
policy_model, policy_mesh = get_maxtext_model(tmvp_config, trainer_devices)
actor_mesh = policy_mesh
Collaborator

For all intents and purposes, we don't need actor_mesh and can just use policy_mesh in line 317.

Or, we could call the vars actor_model and actor_mesh in line 262. I prefer this.

The point is that there is no material difference between "actor" and "policy" in this context, so distinguishing them is confusing.


# ====== System prompt and Templates ======

system_prompt: |
Collaborator

This is dataset-specific and can be moved to the examples script.
I recently added a templates folder in MaxText that contains this template for the GSM8K dataset; we can use that.
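Reading the prompt from a shared templates folder might look like the sketch below; the folder layout and file-naming scheme are assumptions based on the comment, not the actual MaxText templates organization:

```python
from pathlib import Path

# Hypothetical helper: a one-file-per-dataset layout under the templates
# folder is assumed here; the real MaxText templates folder may organize
# prompts differently.
def load_system_prompt(templates_dir, dataset="gsm8k"):
  """Read a dataset's system prompt from the shared templates folder."""
  prompt_file = Path(templates_dir) / f"{dataset}_system_prompt.txt"
  return prompt_file.read_text(encoding="utf-8").strip()
```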

from MaxText import rl_utils


# We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.
Collaborator

The comment at line 46 mentions that this file can also be used to run GRPO on a custom dataset. Can we move all the GSM8K-related code to examples?

mydatascience and others added 20 commits October 31, 2025 22:09
Signed-off-by: Vladimir Suvorov <[email protected]>


6 participants